{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/sithu31296/semantic-segmentation\n",
    "%cd semantic-segmentation\n",
    "%pip install -e ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import io\n",
    "from torchvision import transforms as T\n",
    "from PIL import Image\n",
    "\n",
    "def show_image(image):\n",
    "    \"\"\"Convert an image tensor to a PIL image for notebook display.\n",
    "\n",
    "    A tensor whose third dimension is not of size 3 is treated as\n",
    "    channel-first and permuted to channel-last before conversion.\n",
    "    \"\"\"\n",
    "    channels_last = image.shape[2] == 3\n",
    "    if not channels_last:\n",
    "        image = image.permute(1, 2, 0)\n",
    "    return Image.fromarray(image.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show Available Pretrained Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg import show_models\n",
    "\n",
    "show_models()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a Pretrained Model\n",
    "\n",
    "Download a pretrained model's weights from the result table (ADE20K, Cityscapes, ...) and put them in `checkpoints/pretrained/model_name/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U gdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gdown\n",
    "from pathlib import Path\n",
    "\n",
    "ckpt = Path('./checkpoints/pretrained/segformer')\n",
    "ckpt.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "url = 'https://drive.google.com/uc?id=1-OmW3xRD3WAbJTzktPC-VMOF5WMsN8XT'\n",
    "# derive the target from `ckpt` so the checkpoint directory is spelled only once\n",
    "output = str(ckpt / 'segformer.b3.ade.pth')\n",
    "\n",
    "# skip the download when the weights are already on disk, so re-running the notebook is cheap\n",
    "if not Path(output).exists():\n",
    "    gdown.download(url, output, quiet=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg.models import SegFormer\n",
    "\n",
    "model = SegFormer(\n",
    "    backbone='MiT-B3',\n",
    "    num_classes=150\n",
    ")\n",
    "# switch the model to evaluation mode for inference\n",
    "model.eval()\n",
    "\n",
    "try:\n",
    "    state_dict = torch.load('checkpoints/pretrained/segformer/segformer.b3.ade.pth', map_location='cpu')\n",
    "except FileNotFoundError:\n",
    "    # only a missing checkpoint gets the download hint; any other load error should surface\n",
    "    print(\"Download a pretrained model's weights from the result table.\")\n",
    "else:\n",
    "    # report success only when the weights were actually loaded into the model\n",
    "    model.load_state_dict(state_dict)\n",
    "    print('Loaded Model')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple Image Inference\n",
    "\n",
    "### Load Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = './assests/ade/ADE_val_00000049.jpg'\n",
    "image = io.read_image(image_path)\n",
    "print(image.shape)\n",
    "show_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# center crop to 512x512\n",
    "image = T.CenterCrop((512, 512))(image)\n",
    "# scale to [0.0, 1.0]\n",
    "image = image.float() / 255\n",
    "# normalize\n",
    "image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(image)\n",
    "# add batch size\n",
    "image = image.unsqueeze(0)\n",
    "image.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.inference_mode():\n",
    "    seg = model(image)\n",
    "seg.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Postprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick the highest-scoring class per pixel; softmax is monotonic, so argmax on the raw scores suffices\n",
    "seg = seg.argmax(dim=1)\n",
    "seg.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg.datasets import ADE20K\n",
    "\n",
    "# colour lookup table used to turn class indices into an RGB map\n",
    "palette = ADE20K.PALETTE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seg_map = palette[seg].squeeze().to(torch.uint8)\n",
    "show_image(seg_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show Available Backbones"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg import show_backbones\n",
    "\n",
    "show_backbones()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show Available Heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg import show_heads\n",
    "\n",
    "show_heads()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show Available Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg import show_datasets\n",
    "\n",
    "show_datasets()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct a Custom Model\n",
    "\n",
    "### Choose a Backbone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg.models.backbones import ResNet\n",
    "\n",
    "backbone = ResNet('18')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# init random input batch\n",
    "x = torch.randn(2, 3, 224, 224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get features from the backbone\n",
    "features = backbone(x)\n",
    "for out in features:\n",
    "    print(out.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Choose a Head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semseg.models.heads import UPerHead\n",
    "\n",
    "head = UPerHead(backbone.channels, 128, num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seg = head(features)\n",
    "seg.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import functional as F\n",
    "# upsample the output\n",
    "seg = F.interpolate(seg, size=x.shape[-2:], mode='bilinear', align_corners=False)\n",
    "seg.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See `semseg/models/custom_cnn.py` and `semseg/models/custom_vit.py` for complete examples of constructing a custom model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
  },
  "kernelspec": {
   "display_name": "Python 3.6.9 64-bit",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
